package com.erbadagang.mybatis.plus.tenant.config;

import com.baomidou.mybatisplus.core.parser.ISqlParser;
import com.baomidou.mybatisplus.extension.plugins.PaginationInterceptor;
import com.baomidou.mybatisplus.extension.plugins.tenant.TenantHandler;
import com.baomidou.mybatisplus.extension.plugins.tenant.TenantSqlParser;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.LongValue;
import org.mybatis.spring.annotation.MapperScan;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import java.util.ArrayList;
import java.util.List;

/**
 * @description MyBatisPlus配置类，分页插件，多租户也是使用的分页插件进行的配置。
 * @ClassName: MyBatisPlusConfig
 * @author: 郭秀志 jbcode@126.com
 * @date: 2020/7/12 21:34
 * @Copyright:
 */
@Configuration
@MapperScan("com.erbadagang.mybatis.plus.tenant.mapper")//配置扫描的mapper包
public class MyBatisPlusConfig {

    @Autowired
    private ApiContext apiContext;

    /**
     * 分页插件
     *
     * @return
     */
    @Bean
    public PaginationInterceptor paginationInterceptor() {
        PaginationInterceptor paginationInterceptor = new PaginationInterceptor();

        // 创建SQL解析器集合
        List<ISqlParser> sqlParserList = new ArrayList<>();

        // 创建租户SQL解析器
        TenantSqlParser tenantSqlParser = new TenantSqlParser();

        // 设置租户处理器
        tenantSqlParser.setTenantHandler(new TenantHandler() {

            // 设置当前租户ID，实际情况你可以从cookie、或者缓存中拿都行
            @Override
            public Expression getTenantId(boolean select) {
                // 从当前系统上下文中取出当前请求的服务商ID，通过解析器注入到SQL中。
                Long currentProviderId = apiContext.getCurrentTenantId();
                if (null == currentProviderId) {
                    throw new RuntimeException("Get CurrentProviderId error.");
                }
                return new LongValue(currentProviderId);
            }

            @Override
            public String getTenantIdColumn() {
                // 对应数据库中租户ID的列名
                return "tenant_id";
            }

            @Override
            public boolean doTableFilter(String tableName) {
                // 是否需要需要过滤某一张表
              /*  List<String> tableNameList = Arrays.asList("sys_user");
                if (tableNameList.contains(tableName)){
                    return true;
                }*/
                return false;
            }
        });

        sqlParserList.add(tenantSqlParser);
        paginationInterceptor.setSqlParserList(sqlParserList);

        //有部分SQL不需要加上租户ID的表示，需要过滤特定的sql。如果比较多不建议这里配置。
        /*paginationInterceptor.setSqlParserFilter(new ISqlParserFilter() {
            @Override
            public boolean doFilter(MetaObject metaObject) {
                MappedStatement ms = SqlParserHelper.getMappedStatement(metaObject);
                // 对应Mapper或者dao中的方法
                if("com.erbadagang.mybatis.plus.tenant.mapper.UserMapper.selectList".equals(ms.getId())){
                    return true;
                }
                return false;
            }
        });*/
        return paginationInterceptor;
    }

}